3337e6
@@ -22,6 +22,7 @@
 import java.io.ObjectOutputStream;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashSet;
 import java.util.List;
 
 import org.apache.hadoop.hive.common.type.Decimal128;
@@ -38,6 +39,9 @@
 import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryUtils;
 import org.apache.hadoop.hive.serde2.lazybinary.LazyBinarySerDe.StringWrapper;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
 import org.apache.hadoop.hive.serde2.objectinspector.StructField;
 import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
 
@@ -57,7 +61,7 @@
public static MapJoinKey read(Output output, MapJoinKey key, MapJoinObjectSerDeC
     Object obj = serde.deserialize(writable);
     boolean useOptimized = useOptimizedKeyBasedOnPrev(key);
     if (useOptimized || key == null) {
-      byte[] structBytes = serializeKey(output, obj, serde.getObjectInspector());
+      byte[] structBytes = serializeKey(output, obj, serde.getObjectInspector(), !useOptimized);
       if (structBytes != null) {
         return MapJoinKeyBytes.fromBytes(key, mayReuseKey, structBytes);
       } else if (useOptimized) {
@@ -70,8 +74,29 @@
public static MapJoinKey read(Output output, MapJoinKey key, MapJoinObjectSerDeC
     return result;
   }
 
-  private static byte[] serializeKey(
-      Output byteStream, Object obj, ObjectInspector oi) throws SerDeException {
+  /** Primitive categories serializeKey accepts when checkTypes is set; DECIMAL is excluded. */
+  private static final HashSet<PrimitiveCategory> SUPPORTED_PRIMITIVES
+      = new HashSet<PrimitiveCategory>();
+  static {
+    // All but decimal. A checked key field whose category is not listed is not supported.
+    SUPPORTED_PRIMITIVES.add(PrimitiveCategory.VOID);
+    SUPPORTED_PRIMITIVES.add(PrimitiveCategory.BOOLEAN);
+    SUPPORTED_PRIMITIVES.add(PrimitiveCategory.BYTE);
+    SUPPORTED_PRIMITIVES.add(PrimitiveCategory.SHORT);
+    SUPPORTED_PRIMITIVES.add(PrimitiveCategory.INT);
+    SUPPORTED_PRIMITIVES.add(PrimitiveCategory.LONG);
+    SUPPORTED_PRIMITIVES.add(PrimitiveCategory.FLOAT);
+    SUPPORTED_PRIMITIVES.add(PrimitiveCategory.DOUBLE);
+    SUPPORTED_PRIMITIVES.add(PrimitiveCategory.STRING);
+    SUPPORTED_PRIMITIVES.add(PrimitiveCategory.DATE);
+    SUPPORTED_PRIMITIVES.add(PrimitiveCategory.TIMESTAMP);
+    SUPPORTED_PRIMITIVES.add(PrimitiveCategory.BINARY);
+    SUPPORTED_PRIMITIVES.add(PrimitiveCategory.VARCHAR);
+    SUPPORTED_PRIMITIVES.add(PrimitiveCategory.CHAR);
+  }
+
+  private static byte[] serializeKey(Output byteStream,
+      Object obj, ObjectInspector oi, boolean checkTypes) throws SerDeException {
     if (null == obj || !(oi instanceof StructObjectInspector)) {
       return null; // not supported
     }
@@ -87,8 +112,14 @@
public static MapJoinKey read(Output output, MapJoinKey key, MapJoinObjectSerDeC
     List<ObjectInspector> fieldOis = new ArrayList<ObjectInspector>(size);
     for (int i = 0; i < size; ++i) {
       StructField field = fields.get(i);
+      ObjectInspector foi = field.getFieldObjectInspector();
+      if (checkTypes) {
+        if (foi.getCategory() != Category.PRIMITIVE) return null; // not supported
+        PrimitiveCategory pc = ((PrimitiveObjectInspector)foi).getPrimitiveCategory();
+        if (!SUPPORTED_PRIMITIVES.contains(pc)) return null; // not supported
+      }
       fieldData[i] = soi.getStructFieldData(obj, field);
-      fieldOis.add(field.getFieldObjectInspector());
+      fieldOis.add(foi);
     }
 
     return serializeRowCommon(byteStream, fieldData, fieldOis);
